#!/usr/bin/env python3

# used to generate model: attention_qk_output_0.onnx

import numpy as np
import onnx
import onnx.helper
from onnx import TensorProto
from onnx.reference import ReferenceEvaluator


def build_model():
    """Build an ONNX model containing a single Attention node (opset 23).

    The graph takes q/k/v, a boolean attention mask, and past key/value
    tensors, and produces the attention output, the present key/value
    tensors, and the raw QK matmul output (qk_matmul_output_mode=0).

    Returns:
        onnx.ModelProto: the constructed (unchecked) model.
    """
    # Define the graph inputs and outputs
    q = onnx.helper.make_tensor_value_info("q", TensorProto.FLOAT, [1, 1, 2, 2])
    k = onnx.helper.make_tensor_value_info("k", TensorProto.FLOAT, [1, 1, 1, 2])
    v = onnx.helper.make_tensor_value_info("v", TensorProto.FLOAT, [1, 1, 1, 2])
    attn_mask = onnx.helper.make_tensor_value_info(
        "attn_mask", TensorProto.BOOL, [2, 2]
    )
    past_k = onnx.helper.make_tensor_value_info(
        "past_k", TensorProto.FLOAT, [1, 1, 1, 2]
    )
    past_v = onnx.helper.make_tensor_value_info(
        "past_v", TensorProto.FLOAT, [1, 1, 1, 2]
    )
    y = onnx.helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 1, 2, 2])
    present_k = onnx.helper.make_tensor_value_info(
        "present_k", TensorProto.FLOAT, [1, 1, 2, 2]
    )
    present_v = onnx.helper.make_tensor_value_info(
        "present_v", TensorProto.FLOAT, [1, 1, 2, 2]
    )
    qk_matmul_output = onnx.helper.make_tensor_value_info(
        "qk_matmul_output", TensorProto.FLOAT, [1, 1, 2, 2]
    )

    # Create the Attention node.
    # softcap=2.0 applies attention score soft-capping;
    # qk_matmul_output_mode=0 emits the raw QK matmul result.
    attention = onnx.helper.make_node(
        "Attention",
        inputs=["q", "k", "v", "attn_mask", "past_k", "past_v"],
        outputs=["y", "present_k", "present_v", "qk_matmul_output"],
        name="AttentionNode",
        softcap=2.0,
        qk_matmul_output_mode=0,
    )

    # Create the graph
    graph = onnx.helper.make_graph(
        [attention],
        "AttentionModel",
        [q, k, v, attn_mask, past_k, past_v],
        [y, present_k, present_v, qk_matmul_output],
    )

    # Create the model (Attention requires opset 23)
    model = onnx.helper.make_model(
        opset_imports=[onnx.helper.make_operatorsetid("", 23)],
        graph=graph,
        producer_name="ONNX_Generator",
    )

    return model


if __name__ == "__main__":
    # Set seed and print precision for reproducible, readable output.
    # (Seed is set defensively; the inputs below are hard-coded.)
    np.random.seed(42)
    np.set_printoptions(precision=8)

    # Build test inputs. Use float32 explicitly so the dtypes match the
    # TensorProto.FLOAT inputs declared in the graph (np.array defaults
    # to float64 otherwise).
    q = np.array([[[[1.0, 0.0], [0.0, 1.0]]]], dtype=np.float32)
    k = np.array([[[[1.0, 0.0]]]], dtype=np.float32)
    v = np.array([[[[0.3, 0.6]]]], dtype=np.float32)
    past_k = np.array([[[[0.0, 1.0]]]], dtype=np.float32)
    past_v = np.array([[[[0.25, 0.5]]]], dtype=np.float32)
    onnx_model = build_model()
    file_name = "attention_qk_output_0.onnx"

    # Ensure valid ONNX and save
    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, file_name)
    print(f"Finished exporting model to {file_name}")

    # Output some test data for use in the test
    print(
        f"Test input data: {repr(q)} {repr(k)} {repr(v)} {repr(past_k)} {repr(past_v)}"
    )
    print(
        f"Test input data shape: {q.shape} {k.shape} {v.shape} {past_k.shape} {past_v.shape}"
    )
    # Reuse file_name rather than re-hardcoding the path.
    # NOTE(review): attn_mask is fed as None to leave the optional mask
    # unset — confirm the ReferenceEvaluator accepts None for declared
    # optional inputs in this onnx version.
    session = ReferenceEvaluator(file_name, verbose=1)
    (test_output, present_k, present_v, qk_matmul_output) = session.run(
        None,
        {"q": q, "k": k, "v": v, "attn_mask": None, "past_k": past_k, "past_v": past_v},
    )
    print(
        f"Test output: {repr(test_output)} {repr(present_k)} {repr(present_v)} {repr(qk_matmul_output)}"
    )
    print(
        f"Test output shape: {test_output.shape} {present_k.shape} {present_v.shape} {qk_matmul_output.shape}"
    )
